from __future__ import division
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.patches as patches

def load_classes(path):
    """
    Loads class labels at 'path'.

    The file is read as text and split on newlines, one label per line.

    # Arguments
        path: Path of the label file to read.
    # Returns
        List of label strings in file order.
    """
    # The context manager closes the handle even if reading raises.
    with open(path, "r") as fp:
        names = fp.read().split("\n")
    # A file that ends with a newline produces one trailing empty string.
    # Drop only that entry, so a file without a final newline keeps its
    # last label instead of silently losing it.
    if names and names[-1] == "":
        names.pop()
    return names

def weights_init_normal(m):
    """
    Re-initialises the parameters of one layer in place, chosen by class name.

    Layers whose class name contains 'Conv' get weights drawn from a normal
    distribution with mean 0.0 and standard deviation 0.02. Layers whose class
    name contains 'BatchNorm2d' get weights drawn with mean 1.0 and standard
    deviation 0.02, and their bias set to 0.0. Any other layer is left as is.

    # Arguments
        m: The layer to initialise.
    """
    layer_name = m.__class__.__name__
    if 'Conv' in layer_name:
        # Only the weight is touched here; a convolution bias keeps its value.
        torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm2d' in layer_name:
        torch.nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

def compute_ap(recall, precision):
    """ Compute the average precision, given the recall and precision curves.
    Code originally from https://github.com/rbgirshick/py-faster-rcnn.

    # Arguments
        recall:    The recall curve (list).
        precision: The precision curve (list).
    # Returns
        The average precision as computed in py-faster-rcnn.
    """
    # Pad both curves with sentinel values: recall runs from 0 to 1 and
    # precision is 0 at both ends.
    padded_recall = np.concatenate(([0.], recall, [1.]))
    padded_precision = np.concatenate(([0.], precision, [0.]))

    # Precision envelope: every point becomes the maximum of itself and all
    # points to its right, i.e. a running maximum taken from the end.
    envelope = np.maximum.accumulate(padded_precision[::-1])[::-1]

    # Only positions where recall actually changes contribute area.
    steps = np.flatnonzero(padded_recall[1:] != padded_recall[:-1])

    # Sum (delta recall) * (envelope precision at the right edge of the step).
    widths = padded_recall[steps + 1] - padded_recall[steps]
    return np.sum(widths * envelope[steps + 1])

def bbox_iou(box1, box2, x1y1x2y2=True):
    """
    Returns the IoU of two sets of bounding boxes.

    Boxes are read from columns 0-3 of each input; extra columns are ignored.
    The row-wise operations broadcast, so a single box can be compared against
    many. Widths and heights are computed as (x2 - x1 + 1) and (y2 - y1 + 1).

    # Arguments
        box1:     2-D tensor of boxes.
        box2:     2-D tensor of boxes.
        x1y1x2y2: If True the columns are (x1, y1, x2, y2); otherwise they are
                  (center x, center y, width, height).
    # Returns
        Tensor with one IoU value per compared pair of rows.
    """
    def _corners(box):
        # Return (x1, y1, x2, y2) columns of `box` in corner form.
        if x1y1x2y2:
            return box[:, 0], box[:, 1], box[:, 2], box[:, 3]
        # Center/size layout: move half the size to either side of the center.
        left, right = box[:, 0] - box[:, 2] / 2, box[:, 0] + box[:, 2] / 2
        top, bottom = box[:, 1] - box[:, 3] / 2, box[:, 1] + box[:, 3] / 2
        return left, top, right, bottom

    a_x1, a_y1, a_x2, a_y2 = _corners(box1)
    b_x1, b_y1, b_x2, b_y2 = _corners(box2)

    # Intersection rectangle: the larger of the two near edges and the
    # smaller of the two far edges.
    left = torch.max(a_x1, b_x1)
    top = torch.max(a_y1, b_y1)
    right = torch.min(a_x2, b_x2)
    bottom = torch.min(a_y2, b_y2)

    # Negative extents mean the boxes do not overlap; clamp them to zero.
    overlap_w = torch.clamp(right - left + 1, min=0)
    overlap_h = torch.clamp(bottom - top + 1, min=0)
    inter_area = overlap_w * overlap_h

    area_a = (a_x2 - a_x1 + 1) * (a_y2 - a_y1 + 1)
    area_b = (b_x2 - b_x1 + 1) * (b_y2 - b_y1 + 1)

    # The small constant guards against division by zero.
    return inter_area / (area_a + area_b - inter_area + 1e-16)


def non_max_suppression(prediction, num_classes, conf_thres=0.5, nms_thres=0.4):
    """
    Removes detections with lower object confidence score than 'conf_thres' and performs
    Non-Maximum Suppression to further filter detections.

    NOTE: channels 0-3 of 'prediction' are overwritten in place with corner
    coordinates (x1, y1, x2, y2).

    # Arguments
        prediction:  Tensor indexed as [image, box, channel]. Channels 0-3 are
                     read as (center x, center y, width, height), channel 4 as
                     the object confidence and channels 5 onwards as class scores.
        num_classes: Number of class-score channels read, starting at channel 5.
        conf_thres:  Boxes with object confidence below this value are dropped.
        nms_thres:   Within a class, a box is dropped when its IoU with an
                     already kept box is at or above this value.
    # Returns
        A list with one entry per image. The entry is None when no box of that
        image passes 'conf_thres'; otherwise it is a tensor whose rows are
        (x1, y1, x2, y2, object_conf, class_score, class_pred)
    """
    # From (center x, center y, width, height) to (x1, y1, x2, y2).
    # torch.cat builds a fresh tensor, so the corners are fully computed from
    # the original values before the in-place write happens.
    centers = prediction[:, :, :2]
    half_sizes = prediction[:, :, 2:4] / 2
    prediction[:, :, :4] = torch.cat((centers - half_sizes, centers + half_sizes), 2)

    output = [None] * len(prediction)
    for image_i, image_pred in enumerate(prediction):
        # Filter out confidence scores below threshold.
        # The comparison already yields a 1-D mask. Squeezing it would turn a
        # single-row mask into a 0-d tensor, and indexing with that adds a
        # dimension instead of selecting rows.
        conf_mask = image_pred[:, 4] >= conf_thres
        image_pred = image_pred[conf_mask]

        # If none are remaining => process next image
        if not image_pred.size(0):
            continue
        # Get score and class with highest confidence
        class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)
        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1)
        # Iterate through all predicted classes; the last column holds the
        # class index, unique() keeps each class once.
        unique_labels = detections[:, -1].cpu().unique()
        if prediction.is_cuda:
            unique_labels = unique_labels.cuda()
        for c in unique_labels:
            # Get the detections with the particular class
            detections_class = detections[detections[:, -1] == c]
            # Sort the detections by object confidence, highest first
            _, conf_sort_index = torch.sort(detections_class[:, 4], descending=True)
            detections_class = detections_class[conf_sort_index]
            # Perform non-maximum suppression
            max_detections = []
            while detections_class.size(0):
                # Get detection with highest confidence and save as max detection
                max_detections.append(detections_class[0].unsqueeze(0))
                # Stop if we're at the last detection
                if len(detections_class) == 1:
                    break
                # Get the IOUs for all boxes with lower confidence
                ious = bbox_iou(max_detections[-1], detections_class[1:])
                # Keep only boxes whose IoU with the kept box is below the
                # threshold; the best of those is kept on the next pass.
                detections_class = detections_class[1:][ious < nms_thres]

            max_detections = torch.cat(max_detections).data
            # Add max detections to outputs, appending to earlier classes
            if output[image_i] is None:
                output[image_i] = max_detections
            else:
                output[image_i] = torch.cat((output[image_i], max_detections))

    return output

def build_targets(pred_boxes, target, anchors, num_anchors, num_classes, dim, ignore_thres, img_dim):
    """
    Builds the per-cell training targets and masks for a batch of labelled images.

    Each non-empty target row is assigned to the grid cell containing its
    center and to the anchor whose shape overlaps it best.

    # Arguments
        pred_boxes:   Predicted boxes, indexed as [image, anchor, grid row,
                      grid column]; each entry is compared to the ground truth
                      as (center x, center y, width, height).
        target:       Labels indexed as [image, object, field] with fields
                      (class, x, y, w, h). x, y, w, h are multiplied by 'dim',
                      so they are presumably normalised to [0, 1] - confirm
                      against the data loader. Rows summing to 0 are skipped.
        anchors:      Sequence of (width, height) pairs, compared directly
                      against the target size in grid units.
        num_anchors:  Number of anchors.
        num_classes:  Number of classes.
        dim:          Number of grid cells per side.
        ignore_thres: Anchors whose shape IoU with a target exceeds this value
                      are excluded from the confidence mask at the target cell.
        img_dim:      Not used by this function.
    # Returns
        Tuple (nGT, nCorrect, mask, conf_mask, tx, ty, tw, th, tconf, tcls):
        the number of targets, the number of targets whose assigned prediction
        has IoU > 0.5, and the target/mask tensors of shape
        (images, anchors, dim, dim) (tcls has an extra class dimension).
    """
    nB = target.size(0)  # number of images in the batch
    nA = num_anchors
    mask        = torch.zeros(nB, nA, dim, dim)  # 1 at the (anchor, cell) assigned to a target
    conf_mask   = torch.ones(nB, nA, dim, dim)   # cells that take part in the confidence loss
    tx          = torch.zeros(nB, nA, dim, dim)  # x offset of the center inside its cell
    ty          = torch.zeros(nB, nA, dim, dim)  # y offset of the center inside its cell
    tw          = torch.zeros(nB, nA, dim, dim)  # log of target width / anchor width
    th          = torch.zeros(nB, nA, dim, dim)  # log of target height / anchor height
    tconf       = torch.zeros(nB, nA, dim, dim)  # 1 where a target was assigned
    tcls        = torch.zeros(nB, nA, dim, dim, num_classes)  # one-hot class of the assigned target

    # Anchor boxes placed at the origin as (0, 0, w, h). They do not depend on
    # the target, so build them once instead of once per target.
    anchor_shapes = torch.FloatTensor(np.concatenate((np.zeros((len(anchors), 2)), np.array(anchors)), 1))

    nGT = 0
    nCorrect = 0
    for b in range(nB):
        for t in range(target.shape[1]):
            # All-zero rows are padding, not real objects
            if target[b, t].sum() == 0:
                continue
            nGT += 1
            # Convert to position relative to the grid. Plain floats are used
            # so the boxes below are built from numbers, not 0-d tensors.
            gx = float(target[b, t, 1] * dim)
            gy = float(target[b, t, 2] * dim)
            gw = float(target[b, t, 3] * dim)
            gh = float(target[b, t, 4] * dim)
            # Get grid box indices
            gi = int(gx)
            gj = int(gy)
            # Shape of the gt box, placed at the origin like the anchors
            gt_shape = torch.FloatTensor([[0.0, 0.0, gw, gh]])
            # Calculate iou between gt and anchor shapes
            anch_ious = bbox_iou(gt_shape, anchor_shapes)
            # Where the overlap is larger than threshold set mask to zero
            # (ignore). Only the target's own cell is affected: indexing
            # without (gj, gi) would zero the whole grid of every matching
            # anchor and remove it from the confidence loss for this image.
            conf_mask[b, anch_ious > ignore_thres, gj, gi] = 0
            # Find the best matching anchor box
            best_n = int(torch.argmax(anch_ious))
            # Get ground truth box as (center x, center y, w, h)
            gt_box = torch.FloatTensor([[gx, gy, gw, gh]])
            # Get the prediction at the assigned anchor and cell
            pred_box = pred_boxes[b, best_n, gj, gi].unsqueeze(0)
            # Masks: the assigned (anchor, cell) always counts, even if it was
            # ignored just above
            mask[b, best_n, gj, gi] = 1
            conf_mask[b, best_n, gj, gi] = 1
            # Coordinates: offset of the center from the cell's top-left corner
            tx[b, best_n, gj, gi] = gx - gi
            ty[b, best_n, gj, gi] = gy - gj
            # Width and height; the small constant keeps log() defined at 0
            tw[b, best_n, gj, gi] = math.log(gw / anchors[best_n][0] + 1e-16)
            th[b, best_n, gj, gi] = math.log(gh / anchors[best_n][1] + 1e-16)
            # One-hot encoding of label
            tcls[b, best_n, gj, gi, int(target[b, t, 0])] = 1
            # Calculate iou between ground truth and best matching prediction
            iou = bbox_iou(gt_box, pred_box, x1y1x2y2=False)
            tconf[b, best_n, gj, gi] = 1

            if iou > 0.5:
                nCorrect += 1

    return nGT, nCorrect, mask, conf_mask, tx, ty, tw, th, tconf, tcls

def to_categorical(y, num_classes):
    """ 1-hot encodes a tensor

    # Arguments
        y:           Class index or indices used to pick rows of the identity matrix.
        num_classes: Size of the one-hot vectors.
    # Returns
        uint8 tensor holding one one-hot row per index in 'y'.
    """
    # Row k of the identity matrix is the one-hot vector for class k.
    identity = np.eye(num_classes, dtype='uint8')
    one_hot = identity[y]
    return torch.from_numpy(one_hot)
